-
Notifications
You must be signed in to change notification settings - Fork 557
Add support for float mask to aten::masked_fill #1337
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
d80ee26
to
09e6850
Compare
09e6850
to
53fbe7b
Compare
Can you split the changes into 3 PRs? They are all quite independent from one another. It would also make the commit titles a lot more descriptive, since each PR could get as a title the description you have in the bullet points above. |
53fbe7b
to
b300854
Compare
I've split the PRs, this one is for masked_fill now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just have a small change request, but other than that it LGTM
003853f
to
f73493e
Compare
@@ -906,6 +906,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( | |||
|
|||
Value input = payloadArgs[0]; | |||
Value mask = payloadArgs[1]; | |||
if (mask.getType().isa<mlir::FloatType>()) | |||
mask = b.create<arith::ConstantOp>(loc, b.getBoolAttr(false)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I missed this the first time I reviewed your changes. Why is mask
being set to false
here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That seems to be expected behavior. I was casting to Int1
at first, but further testing shows that it seems to treat all floats as false. I haven't found anything in the documentation about it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems to be a bug upstream. I would actually expect float
mask to result in a runtime error, since this is the behavior that aten.masked_select
has:
Can you file a bug upstream for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, the bug report is up here. I'm going to leave this as-is for now in case it's expected behavior but I'll add an assert if it isn't.
Rewrite the deprecated ONNX Scatter operation (https://github.com/onnx/onnx/blob/main/docs/Operators.md#Scatter) using the equivalent ScatterElements operation.
This adds a few small changes that are needed for OPT support, namely:
Support for
aten::view
when the output shape is statically knownFolding away
torch::type_as
when both arguments are the same typeSupport for
torch::masked_fill
when the mask is a float type